{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67337d44",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import pyceres\n",
    "import pycolmap\n",
    "import numpy as np\n",
    "from hloc.utils import viz_3d"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc59c6d8",
   "metadata": {},
   "source": [
    "## Synthetic reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a155e6d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_reconstruction(num_points=50, num_images=2, seed=3, noise=0):\n",
    "    \"\"\"Build a synthetic reconstruction with one camera and random poses.\n",
    "\n",
    "    Args:\n",
    "        num_points: number of 3D points to sample.\n",
    "        num_images: number of images to add, with ids 0..num_images-1.\n",
    "        seed: seed of the RandomState that drives all the sampling.\n",
    "        noise: factor applied to the standard-normal noise added to the\n",
    "            projected 2D points.\n",
    "\n",
    "    Returns:\n",
    "        The populated pycolmap.Reconstruction; each image holds one 2D point\n",
    "        per 3D point of the reconstruction.\n",
    "    \"\"\"\n",
    "    rng = np.random.RandomState(seed)\n",
    "    rec = pycolmap.Reconstruction()\n",
    "\n",
    "    # Points are drawn uniformly in a cube of side 2 centered at (0, 0, 3).\n",
    "    cube_center = np.array([0, 0, 3])\n",
    "    for xyz in rng.uniform(-1, 1, (num_points, 3)) + cube_center:\n",
    "        rec.add_point3D(xyz, pycolmap.Track(), np.zeros(3))\n",
    "\n",
    "    width, height = 640, 480\n",
    "    focal = max(width, height) * 1.2\n",
    "    cam = pycolmap.Camera(\n",
    "        model=\"SIMPLE_PINHOLE\",\n",
    "        width=width,\n",
    "        height=height,\n",
    "        params=np.array([focal, width / 2, height / 2]),\n",
    "        camera_id=0,\n",
    "    )\n",
    "    rec.add_camera(cam)\n",
    "\n",
    "    for image_id in range(num_images):\n",
    "        # Default-constructed rotation, translation uniform in [-1, 1]^3.\n",
    "        pose = pycolmap.Rigid3d(pycolmap.Rotation3d(), rng.uniform(-1, 1, 3))\n",
    "        im = pycolmap.Image(\n",
    "            id=image_id,\n",
    "            name=str(image_id),\n",
    "            camera_id=cam.camera_id,\n",
    "            cam_from_world=pose,\n",
    "        )\n",
    "        im.registered = True\n",
    "        # Project all the 3D points into this image.\n",
    "        xyz_world = [pt.xyz for pt in rec.points3D.values()]\n",
    "        p2d = cam.img_from_cam(im.cam_from_world * xyz_world)\n",
    "        # randn is called even when noise == 0, so the generator state does\n",
    "        # not depend on the noise level.\n",
    "        p2d_obs = np.array(p2d) + rng.randn(len(p2d), 2) * noise\n",
    "        observations = [\n",
    "            pycolmap.Point2D(xy, point3D_id)\n",
    "            for xy, point3D_id in zip(p2d_obs, rec.points3D)\n",
    "        ]\n",
    "        im.points2D = pycolmap.ListPoint2D(observations)\n",
    "        rec.add_image(im)\n",
    "    return rec\n",
    "\n",
    "\n",
    "rec_gt = create_reconstruction()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e69461f-3619-41a5-b202-c6dc8e2fc778",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the ground-truth reconstruction in red.\n",
    "fig = viz_3d.init_figure()\n",
    "gt_style = dict(min_track_length=0, color=\"rgb(255,0,0)\", points_rgb=False)\n",
    "viz_3d.plot_reconstruction(fig, rec_gt, **gt_style)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98a6db79",
   "metadata": {},
   "source": [
    "## Optimize 3D points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e199af72",
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_problem(rec):\n",
    "    \"\"\"Build a Ceres problem that refines the 3D points of `rec`.\n",
    "\n",
    "    One reprojection residual is added per 2D observation; the image pose is\n",
    "    passed to the cost function itself. The camera parameters belong to each\n",
    "    residual block but are held constant.\n",
    "\n",
    "    Returns:\n",
    "        The pyceres.Problem, ready to be solved.\n",
    "    \"\"\"\n",
    "    problem = pyceres.Problem()\n",
    "    loss = pyceres.TrivialLoss()\n",
    "    for image in rec.images.values():\n",
    "        camera = rec.cameras[image.camera_id]\n",
    "        for obs in image.points2D:\n",
    "            cost = pycolmap.cost_functions.ReprojErrorCost(\n",
    "                camera.model, image.cam_from_world, obs.xy\n",
    "            )\n",
    "            point3D = rec.points3D[obs.point3D_id]\n",
    "            problem.add_residual_block(cost, loss, [point3D.xyz, camera.params])\n",
    "    # Freeze every camera of the reconstruction.\n",
    "    for camera in rec.cameras.values():\n",
    "        problem.set_parameter_block_constant(camera.params)\n",
    "    return problem\n",
    "\n",
    "\n",
    "def solve(prob):\n",
    "    \"\"\"Print the size of `prob`, solve it with Ceres and print a brief report.\n",
    "\n",
    "    The first printed line lists the number of parameter blocks, parameters,\n",
    "    residual blocks and residuals, in that order.\n",
    "    \"\"\"\n",
    "    problem_size = (\n",
    "        prob.num_parameter_blocks(),\n",
    "        prob.num_parameters(),\n",
    "        prob.num_residual_blocks(),\n",
    "        prob.num_residuals(),\n",
    "    )\n",
    "    print(*problem_size)\n",
    "\n",
    "    options = pyceres.SolverOptions()\n",
    "    options.linear_solver_type = pyceres.LinearSolverType.DENSE_QR\n",
    "    options.minimizer_progress_to_stdout = True\n",
    "    # NOTE(review): -1 presumably means \"use all available threads\" - confirm.\n",
    "    options.num_threads = -1\n",
    "\n",
    "    summary = pyceres.SolverSummary()\n",
    "    pyceres.solve(options, prob, summary)\n",
    "    print(summary.BriefReport())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac2305ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: default reconstruction, i.e. no observation noise and\n",
    "# unperturbed 3D points.\n",
    "rec = create_reconstruction()\n",
    "problem = define_problem(rec)\n",
    "solve(problem)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43df4cd1",
   "metadata": {},
   "source": [
    "Add some noise to the 3D points and optimize them again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "356a7ad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "rec = create_reconstruction()\n",
    "# Draw all the offsets from a single seeded generator: re-creating\n",
    "# RandomState(0) for each point would add the very same vector to every\n",
    "# point, i.e. a rigid shift of the cloud instead of per-point noise.\n",
    "noise_rng = np.random.RandomState(0)\n",
    "for p in rec.points3D.values():\n",
    "    p.xyz += noise_rng.uniform(-0.5, 0.5, 3)\n",
    "print(rec.points3D[1].xyz)\n",
    "problem = define_problem(rec)\n",
    "solve(problem)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e24927e",
   "metadata": {},
   "source": [
    "## Optimize poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d9aca7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_problem2(rec):\n",
    "    \"\"\"Build a Ceres problem that refines the image poses of `rec`.\n",
    "\n",
    "    Each 2D observation adds a reprojection residual over four parameter\n",
    "    blocks: the image rotation quaternion, the image translation, the 3D\n",
    "    point and the camera parameters. Camera parameters and 3D points are\n",
    "    then held constant.\n",
    "\n",
    "    Returns:\n",
    "        The pyceres.Problem, ready to be solved.\n",
    "    \"\"\"\n",
    "    problem = pyceres.Problem()\n",
    "    loss = pyceres.TrivialLoss()\n",
    "    for image in rec.images.values():\n",
    "        camera = rec.cameras[image.camera_id]\n",
    "        for obs in image.points2D:\n",
    "            cost = pycolmap.cost_functions.ReprojErrorCost(camera.model, obs.xy)\n",
    "            pose = image.cam_from_world\n",
    "            problem.add_residual_block(\n",
    "                cost,\n",
    "                loss,\n",
    "                [\n",
    "                    pose.rotation.quat,\n",
    "                    pose.translation,\n",
    "                    rec.points3D[obs.point3D_id].xyz,\n",
    "                    camera.params,\n",
    "                ],\n",
    "            )\n",
    "        # Set once per image, after its residual blocks have been added.\n",
    "        quat = image.cam_from_world.rotation.quat\n",
    "        problem.set_manifold(quat, pyceres.EigenQuaternionManifold())\n",
    "    for camera in rec.cameras.values():\n",
    "        problem.set_parameter_block_constant(camera.params)\n",
    "    for point3D in rec.points3D.values():\n",
    "        problem.set_parameter_block_constant(point3D.xyz)\n",
    "    return problem\n",
    "\n",
    "\n",
    "def relative_poses(rec_est, rec_ref):\n",
    "    \"\"\"Return est.cam_from_world * ref.cam_from_world.inverse() for each image.\n",
    "\n",
    "    Images are matched by id, iterating over the ids of `rec_est`.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        rec_est.images[i].cam_from_world * rec_ref.images[i].cam_from_world.inverse()\n",
    "        for i in rec_est.images\n",
    "    ]\n",
    "\n",
    "\n",
    "def print_pose_errors(rel_poses):\n",
    "    \"\"\"Print the translation norms, then the rotation angles passed through np.rad2deg.\"\"\"\n",
    "    print([np.linalg.norm(t.translation) for t in rel_poses])\n",
    "    print([np.rad2deg(t.rotation.angle()) for t in rel_poses])\n",
    "\n",
    "\n",
    "rec = create_reconstruction()\n",
    "# Seeded generator, so that the perturbation (and the errors printed below)\n",
    "# is identical on every Restart & Run All.\n",
    "pose_rng = np.random.RandomState(0)\n",
    "for im in rec.images.values():\n",
    "    im.cam_from_world.translation += pose_rng.randn(3) / 2\n",
    "    im.cam_from_world.rotation *= pycolmap.Rotation3d(pose_rng.randn(3) / 5)\n",
    "    im.cam_from_world.rotation.normalize()\n",
    "# Snapshot of the perturbed state, taken before the optimization modifies rec.\n",
    "rec_init = rec.__deepcopy__({})\n",
    "\n",
    "# Pose errors with respect to the ground truth, before the optimization...\n",
    "init_from_gt = relative_poses(rec, rec_gt)\n",
    "print_pose_errors(init_from_gt)\n",
    "problem = define_problem2(rec)\n",
    "solve(problem)\n",
    "# ... and after it.\n",
    "opt_from_gt = relative_poses(rec, rec_gt)\n",
    "print_pose_errors(opt_from_gt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31196ca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The camera parameters must still match the ground truth.\n",
    "assert np.allclose(rec.cameras[0].params, rec_gt.cameras[0].params)\n",
    "# Optimized pose next to the ground-truth pose, image by image.\n",
    "for image_id in rec.images:\n",
    "    pose_opt = rec.images[image_id].cam_from_world\n",
    "    pose_gt = rec_gt.images[image_id].cam_from_world\n",
    "    print(pose_opt.translation, pose_gt.translation)\n",
    "    print(pose_opt.rotation.quat, pose_gt.rotation.quat)\n",
    "rec.points3D[1].xyz, rec_gt.points3D[1].xyz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aab4bcbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the initial (white), ground-truth (red) and optimized (green)\n",
    "# reconstructions in a single figure.\n",
    "fig = viz_3d.init_figure()\n",
    "layers = [\n",
    "    (rec_init, \"rgb(255,255,255)\", \"init\"),\n",
    "    (rec_gt, \"rgb(255,0,0)\", \"GT\"),\n",
    "    (rec, \"rgb(0,255,0)\", \"optimized\"),\n",
    "]\n",
    "for layer_rec, layer_color, layer_name in layers:\n",
    "    viz_3d.plot_reconstruction(\n",
    "        fig,\n",
    "        layer_rec,\n",
    "        min_track_length=0,\n",
    "        color=layer_color,\n",
    "        points_rgb=False,\n",
    "        name=layer_name,\n",
    "    )\n",
    "fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
